package it.uniroma2.util.math.statistical;

import java.util.*;
import java.io.*;


/**
 * ROC (Receiver Operating Characteristic) analysis of a scored TRUE set
 * against a scored CONTROL set.
 *
 * <p>Each sample is a single score. Samples are ranked by decreasing score,
 * the ROC curve points are generated and the area under the curve (AROC) is
 * computed as described in T. Fawcett, "ROC Graphs: notes and practical
 * considerations for data mining researchers".
 *
 * <p>The area of the last analysis is stored in the public field {@link #A},
 * so a single instance should not be shared between concurrent analyses.
 */
public class ROCanalyzer
{

    /** AROC area computed by the last call to {@link #ROC}; NaN when P or N is zero. */
    public float A = 0;

    /** When true, intermediate values are printed on standard output. */
    boolean verbose = false;

    /** One point of the ROC curve together with the classification statistics at that threshold. */
    private static class point {
        protected float FPrate_;   // 1-specificity coordinate (x-axis)
        protected float TPrate_;   // sensitivity coordinate (y-axis)
        protected float prec_;
        protected float rec_;
        protected float fMeasure_;
        protected float acc_;

        /**
         * Builds a point from the confusion-matrix counts at one threshold.
         * Undefined ratios (e.g. precision when TP+FP == 0) are NaN, as produced
         * by IEEE float division 0/0.
         *
         * @param TP_ true positives
         * @param FP_ false positives
         * @param TN_ true negatives
         * @param FN_ false negatives
         * @param P_  total number of positive samples
         * @param N_  total number of negative samples
         */
        protected point(int TP_, int FP_, int TN_, int FN_, int P_, int N_) {
            float TP = TP_;
            float FP = FP_;
            float TN = TN_;
            float P = P_;
            float N = N_;

            FPrate_ = FP / N;
            TPrate_ = TP / P;
            prec_ = TP / (TP + FP);
            rec_ = TPrate_;
            fMeasure_ = (2 * rec_ * prec_) / (rec_ + prec_);
            acc_ = (TP + TN) / (P + N);
        }

        /** Tab-separated FP rate, TP rate, precision, recall, f-measure and accuracy. */
        protected String toStringa() {
            return FPrate_ + "\t" + TPrate_ + "\t" + prec_ + "\t" + rec_ + "\t" + fMeasure_ + "\t" + acc_;
        }
    }

    /** One scored sample taking part in the ROC analysis. */
    private static class ROCdata {
        protected int id_;            // progressive id of the instance
        protected boolean positive_;  // true if the sample comes from the TRUE set
        protected float f_;           // score the ROC analysis is applied to

        /**
         * Parses one input line. Only the first tab-separated token is read
         * and must be a float; any further tokens are ignored.
         *
         * @throws IllegalArgumentException if the line has no parsable leading score
         */
        protected ROCdata(String in, int id, boolean positive) {
            id_ = id;
            positive_ = positive;
            try {
                StringTokenizer st = new StringTokenizer(in, "\t");
                f_ = Float.parseFloat(st.nextToken());
            } catch (RuntimeException e) {
                // NoSuchElementException (empty line), NumberFormatException, NPE (null entry)
                throw new IllegalArgumentException("Incorrect file format in line: <" + in
                        + ">\nFormat must be <function value>", e);
            }
        }
    }

    /**
     * Orders samples by decreasing score. Float.compare is a total order, so the
     * Comparator contract holds even for ties and NaN scores.
     */
    private static final Comparator<ROCdata> DECREASING_SCORE = new Comparator<ROCdata>() {
        @Override
        public int compare(ROCdata r1, ROCdata r2) {
            return Float.compare(r2.f_, r1.f_); // arguments swapped: decreasing order
        }
    };

    /**
     * Command-line entry point. Reads a TRUE set and a CONTROL set, writes the
     * ROC curve points and the AROC area to the output file, and prints the area.
     *
     * <p>Each input line holds one score (e.g. {@code 0.3424}); blank lines are
     * ignored. Each output point line has the format
     * {@code <FP rate>\t<TP rate>\t<precision>\t<recall>\t<f-Measure>\t<accuracy>},
     * where FP rate is 1-specificity and TP rate is sensitivity.
     *
     * <p>Exits with status -1 on wrong arguments or on a malformed input line.
     */
    public static void main(String[] argv) throws Exception {
        if (argv.length != 3) {
            System.out.println("Arguments:\n\t1) trueSet file\n\t2) controlSet\n\t3) outFile\n");
            System.exit(-1);
        } else {
            try {
                new ROCanalyzer().Run(argv[0], argv[1], argv[2]);
            } catch (IllegalArgumentException e) {
                // malformed input line: keep the historical CLI behaviour (message + exit -1)
                System.out.println(e.getMessage());
                e.printStackTrace();
                System.exit(-1);
            }
        }
    }

    /**
     * Reads both score files, runs the ROC analysis and writes the curve points
     * and the AROC area to {@code outF}. The area is also printed on standard
     * output and left in {@link #A}.
     *
     * @throws IllegalArgumentException if an input line has no parsable score
     * @throws IOException on file access errors
     */
    public void Run(String trueFile, String controlFile, String outF) throws Exception {
        Vector<String> positives = readLines(trueFile);
        Vector<String> controls = readLines(controlFile);
        Vector<point> p = analyze(positives, controls);

        if (verbose) printCurve(p);
        else System.out.print("" + A);

        BufferedWriter outFile = new BufferedWriter(new FileWriter(outF));
        try {
            outFile.write("------------- ROC CURVE POINTS -------------\n\n");
            outFile.write("FPrate\tTPrate\tprec\trec\tf-meas\taccur\n");
            for (point pt : p)
                outFile.write(pt.toStringa() + "\n");
            outFile.write("\n\n------------- AROC AREAS : " + A + "\n");
        } finally {
            outFile.close();
        }
    }

    /**
     * Computes the AROC area for two in-memory lists of scores, each element
     * formatted like an input-file line. Blank elements are ignored. The area is
     * also left in {@link #A}.
     *
     * @param positives        scores of the TRUE set
     * @param control_elements scores of the CONTROL set
     * @return the area under the ROC curve (NaN if either set is empty)
     * @throws IllegalArgumentException if an element has no parsable score
     */
    public double aroc(Vector<String> positives, Vector<String> control_elements) {
        Vector<point> p = analyze(positives, control_elements);
        if (verbose) printCurve(p);
        return A;
    }

    /** Reads every line of a text file. */
    private static Vector<String> readLines(String fileName) throws IOException {
        Vector<String> lines = new Vector<String>();
        BufferedReader br = new BufferedReader(new FileReader(fileName));
        try {
            String line;
            while ((line = br.readLine()) != null)
                lines.add(line);
        } finally {
            br.close();
        }
        return lines;
    }

    /** Builds and ranks the samples, then runs ROC; the area is left in A. */
    private Vector<point> analyze(Vector<String> positives, Vector<String> controls) {
        Vector<ROCdata> vCouple = new Vector<ROCdata>();
        int P = addSamples(vCouple, positives, true);
        int N = addSamples(vCouple, controls, false);
        vCouple = sort(vCouple);

        if (verbose) {
            System.out.println("POINTS IN ROC INPUT VECTOR: " + vCouple.size());
            for (ROCdata d : vCouple)
                System.out.println(d.f_);
        }
        return ROC(vCouple, P, N);
    }

    /** Parses the non-blank entries of {@code lines} into {@code target}; returns how many were added. */
    private static int addSamples(Vector<ROCdata> target, Vector<String> lines, boolean positive) {
        int added = 0;
        for (String line : lines) {
            if (line != null && line.trim().length() == 0)
                continue; // tolerate blank lines, e.g. a trailing newline in a file
            target.add(new ROCdata(line, target.size(), positive));
            added++;
        }
        return added;
    }

    /** Prints the curve points and the area on standard output. */
    private void printCurve(Vector<point> p) {
        System.out.println("POINTS IN ROC CURVE: " + p.size());
        System.out.println("------------- ROC CURVE POINTS -------------\n");
        System.out.println("FPrate\tTPrate\tprec\trec\tf-meas\taccur\n");
        for (point pt : p)
            System.out.println(pt.toStringa());
        System.out.println("\n\n------------- AROC AREAS : \n" + A);
    }

    /**
     * Sorts (in place) a vector of ROC data in decreasing order of score.
     *
     * @return the same vector, sorted
     */
    public Vector<ROCdata> sort(Vector<ROCdata> v) {
        Collections.sort(v, DECREASING_SCORE);
        return v;
    }

    /**
     * ROC / AROC algorithm, as in "ROC Graphs: notes and practical considerations
     * for data mining researchers" by Tom Fawcett.
     *
     * @param vSample samples sorted by decreasing score
     * @param P       number of positive samples
     * @param N       number of negative samples
     * @return the ROC curve points, from (0,0) to (1,1); the area is stored in {@link #A}
     */
    public Vector<point> ROC(Vector<ROCdata> vSample, int P, int N) {
        int FP = 0;
        int TP = 0;
        int FPprev = 0;
        int TPprev = 0;
        // accumulate in double: the unnormalized area grows up to N*P
        double area = 0;
        Vector<point> vPoint = new Vector<point>();

        // NaN never equals any score, so the first sample always emits the (0,0) point
        float fPrev = Float.NaN;
        for (ROCdata sample : vSample) {
            if (sample.f_ != fPrev) {
                area += trapAreaPrecise(FP, FPprev, TP, TPprev);
                vPoint.add(new point(TP, FP, N - FP, P - TP, P, N));
                fPrev = sample.f_;
                FPprev = FP;
                TPprev = TP;
            }
            if (sample.positive_)
                TP++;
            else
                FP++;
        }
        area += trapAreaPrecise(N, FPprev, P, TPprev);
        vPoint.add(new point(TP, FP, N - FP, P - TP, P, N));

        // widen before multiplying: N*P overflows int for sets of ~46k samples each
        A = (float) (area / ((double) N * P));
        if (verbose) System.out.println("N:" + N);
        if (verbose) System.out.println("P:" + P);
        return vPoint;
    }

    /** Calculates the area of a trapezoid, given its point coordinates. */
    public float trapArea(float X1, float X2, float Y1, float Y2) {
        return (float) trapAreaPrecise(X1, X2, Y1, Y2);
    }

    /** Double-precision trapezoid area: |x1-x2| * (y1+y2)/2. */
    private double trapAreaPrecise(double x1, double x2, double y1, double y2) {
        double base = Math.abs(x1 - x2);
        double high = (y1 + y2) / 2;
        if (verbose) System.out.println("X1=" + x1 + "\tX2=" + x2 + "\tY1=" + y1 + "\tY2=" + y2);
        if (verbose) System.out.println("base=" + base + "\thigh=" + high + "\tarea=" + base * high);
        return base * high;
    }

}